{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nus_wide_data_util import TwoPartyNusWideDataLoader, retrieve_top_k_labels\n",
    "\n",
    "\n",
    "def _print_shapes(image, text, labels):\n",
    "    \"\"\"Print the `.shape` of the image, text and label arrays, in that order.\"\"\"\n",
    "    shape_reports = (\n",
    "        (\"[INFO] image shape: {}\", image),\n",
    "        (\"[INFO] text shape: {}\", text),\n",
    "        (\"[INFO] labels shape: {}\", labels),\n",
    "    )\n",
    "    for template, array in shape_reports:\n",
    "        print(template.format(array.shape))\n",
    "\n",
    "\n",
    "def get_data_with_multi_classes(data_dir, labels):\n",
    "    \"\"\"Load the training split in multi-class mode and print the resulting array shapes.\n",
    "\n",
    "    Args:\n",
    "        data_dir: directory handed to `TwoPartyNusWideDataLoader`.\n",
    "        labels: target labels, forwarded to the loader as `target_labels`.\n",
    "\n",
    "    Nothing is returned; the loaded arrays are only summarised on stdout.\n",
    "    \"\"\"\n",
    "    print(\"[INFO] target_label: {0}\".format(labels))\n",
    "    loader = TwoPartyNusWideDataLoader(data_dir)\n",
    "    # Original author's note: loader.get_test_data(...) takes the same keyword\n",
    "    # arguments and can be swapped in to inspect the test split instead.\n",
    "    image_x, text_x, label_y = loader.get_train_data(target_labels=labels, binary_classification=False)\n",
    "    _print_shapes(image_x, text_x, label_y)\n",
    "\n",
    "\n",
    "def get_data_with_binary_classes(data_dir, labels):\n",
    "    \"\"\"Load the training split in binary mode and print the resulting array shapes.\n",
    "\n",
    "    The loader is created with `binary_negative_label=0`.\n",
    "\n",
    "    Args:\n",
    "        data_dir: directory handed to `TwoPartyNusWideDataLoader`.\n",
    "        labels: target labels, forwarded to the loader as `target_labels`.\n",
    "\n",
    "    Nothing is returned; the loaded arrays are only summarised on stdout.\n",
    "    \"\"\"\n",
    "    loader = TwoPartyNusWideDataLoader(data_dir, binary_negative_label=0)\n",
    "    # Original author's note: loader.get_test_data(...) takes the same keyword\n",
    "    # arguments and can be swapped in to inspect the test split instead.\n",
    "    image_x, text_x, label_y = loader.get_train_data(target_labels=labels, binary_classification=True)\n",
    "    _print_shapes(image_x, text_x, label_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Load top 10 labels: \n",
    "\n",
    "'sky', 'clouds', 'person', 'water', 'animal', 'grass', 'buildings', 'window', 'plants', 'lake'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = \"Input Your Directory Here\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['sky', 'clouds', 'person', 'water', 'animal', 'grass', 'buildings', 'window', 'plants', 'lake']\n"
     ]
    }
   ],
   "source": [
    "top_k_labels = retrieve_top_k_labels(data_dir, top_k=10)\n",
    "print(top_k_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load data with 10 classes using `get_data_with_multi_classes`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] target_label: ['sky', 'clouds', 'person', 'water', 'animal', 'grass', 'buildings', 'window', 'plants', 'lake']\n",
      "[INFO] load data with labels:['sky', 'clouds', 'person', 'water', 'animal', 'grass', 'buildings', 'window', 'plants', 'lake'] for multi-classification.\n",
      "[INFO] load rows with valid label (58619, 10)\n",
      "[INFO] load image feature (Train_Normalized_CM55.dat) with (225) dimension.\n",
      "[INFO] load image feature (Train_Normalized_EDH.dat) with (73) dimension.\n",
      "[INFO] load image feature (Train_Normalized_CH.dat) with (64) dimension.\n",
      "[INFO] load image feature (Train_Normalized_WT.dat) with (128) dimension.\n",
      "[INFO] load image feature (Train_Normalized_CORR.dat) with (144) dimension.\n",
      "[INFO] load all image feature with shape: (161789, 634)\n",
      "[INFO] load all text feature (Train_Tags1k.dat) with shape: (161789, 1000).\n",
      "[INFO] image shape: (58619, 634)\n",
      "[INFO] text shape: (58619, 1000)\n",
      "[INFO] labels shape: (58619, 10)\n"
     ]
    }
   ],
   "source": [
    "target_labels = ['sky', 'clouds', 'person', 'water', 'animal', 'grass', 'buildings', 'window', 'plants', 'lake']\n",
    "get_data_with_multi_classes(data_dir, target_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load data with `target_labels = None` using `get_data_with_multi_classes`, which loads the top 5 classes by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] target_label: None\n",
      "[INFO] load data with labels:['sky', 'clouds', 'person', 'water', 'animal'] for multi-classification.\n",
      "[INFO] load rows with valid label (60157, 5)\n",
      "[INFO] load image feature (Train_Normalized_CM55.dat) with (225) dimension.\n",
      "[INFO] load image feature (Train_Normalized_EDH.dat) with (73) dimension.\n",
      "[INFO] load image feature (Train_Normalized_CH.dat) with (64) dimension.\n",
      "[INFO] load image feature (Train_Normalized_WT.dat) with (128) dimension.\n",
      "[INFO] load image feature (Train_Normalized_CORR.dat) with (144) dimension.\n",
      "[INFO] load all image features with shape: (161789, 634)\n",
      "[INFO] load all text features (Train_Tags1k.dat) with shape: (161789, 1000).\n",
      "[INFO] image shape: (60157, 634)\n",
      "[INFO] text shape: (60157, 1000)\n",
      "[INFO] labels shape: (60157, 5)\n"
     ]
    }
   ],
   "source": [
    "target_labels = None\n",
    "get_data_with_multi_classes(data_dir, target_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading data with 2 classes using `get_data_with_multi_classes` raises an exception"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] target_label: ['sky', 'clouds']\n",
      "Exception occur:Multi-classification does not support the number of classes smaller than or equal to 2\n"
     ]
    }
   ],
   "source": [
    "target_labels = ['sky', 'clouds']\n",
    "try:\n",
    "    get_data_with_multi_classes(data_dir, target_labels)\n",
    "except Exception as exc:\n",
    "    print(f\"Exception occur:{exc}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load data with 2 classes using `get_data_with_binary_classes` "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] load data with labels:['sky'] vs ['person'] for binary-classification.\n",
      "[INFO] load rows with valid label (68278, 2)\n",
      "[INFO] load image feature (Train_Normalized_CM55.dat) with (225) dimension.\n",
      "[INFO] load image feature (Train_Normalized_EDH.dat) with (73) dimension.\n",
      "[INFO] load image feature (Train_Normalized_CH.dat) with (64) dimension.\n",
      "[INFO] load image feature (Train_Normalized_WT.dat) with (128) dimension.\n",
      "[INFO] load image feature (Train_Normalized_CORR.dat) with (144) dimension.\n",
      "[INFO] load all image features with shape: (161789, 634)\n",
      "[INFO] load all text features (Train_Tags1k.dat) with shape: (161789, 1000).\n",
      "[INFO] # of positive samples: 40812, # of negative samples: 27466\n",
      "[INFO] image shape: (68278, 634)\n",
      "[INFO] text shape: (68278, 1000)\n",
      "[INFO] labels shape: (68278,)\n"
     ]
    }
   ],
   "source": [
    "target_labels = [\"sky\", \"person\"]\n",
    "get_data_with_binary_classes(data_dir, target_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load data with only one target label using `get_data_with_binary_classes` \n",
    "\n",
    "**When using `get_data_with_binary_classes` with only one target label, that label is treated as the positive label, and the other labels among the top K labels (the top 5 here) are treated as negative labels.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] load data with labels:['water'] vs ['sky', 'clouds', 'person', 'animal'] for binary-classification.\n",
      "[INFO] load rows with valid label (60157, 5)\n",
      "[INFO] load image feature (Train_Normalized_CM55.dat) with (225) dimension.\n",
      "[INFO] load image feature (Train_Normalized_EDH.dat) with (73) dimension.\n",
      "[INFO] load image feature (Train_Normalized_CH.dat) with (64) dimension.\n",
      "[INFO] load image feature (Train_Normalized_WT.dat) with (128) dimension.\n",
      "[INFO] load image feature (Train_Normalized_CORR.dat) with (144) dimension.\n",
      "[INFO] load all image features with shape: (161789, 634)\n",
      "[INFO] load all text features (Train_Tags1k.dat) with shape: (161789, 1000).\n",
      "[INFO] # of positive samples: 5454, # of negative samples: 54703\n",
      "[INFO] image shape: (60157, 634)\n",
      "[INFO] text shape: (60157, 1000)\n",
      "[INFO] labels shape: (60157,)\n"
     ]
    }
   ],
   "source": [
    "target_labels = [\"water\"]\n",
    "get_data_with_binary_classes(data_dir, target_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] load data with labels:['plants'] vs ['sky', 'clouds', 'person', 'water'] for binary-classification.\n",
      "[INFO] load rows with valid label (51702, 5)\n",
      "[INFO] load image feature (Train_Normalized_CM55.dat) with (225) dimension.\n",
      "[INFO] load image feature (Train_Normalized_EDH.dat) with (73) dimension.\n",
      "[INFO] load image feature (Train_Normalized_CH.dat) with (64) dimension.\n",
      "[INFO] load image feature (Train_Normalized_WT.dat) with (128) dimension.\n",
      "[INFO] load image feature (Train_Normalized_CORR.dat) with (144) dimension.\n",
      "[INFO] load all image features with shape: (161789, 634)\n",
      "[INFO] load all text features (Train_Tags1k.dat) with shape: (161789, 1000).\n",
      "[INFO] # of positive samples: 4835, # of negative samples: 46867\n",
      "[INFO] image shape: (51702, 634)\n",
      "[INFO] text shape: (51702, 1000)\n",
      "[INFO] labels shape: (51702,)\n"
     ]
    }
   ],
   "source": [
    "target_labels = [\"plants\"]\n",
    "get_data_with_binary_classes(data_dir, target_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load data with more than two classes using `get_data_with_binary_classes`, which will throw an exception"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exception occur:binary classification does not support 3 # of labels, which are ['sky', 'person', 'plants'].\n"
     ]
    }
   ],
   "source": [
    "target_labels = [\"sky\", \"person\", \"plants\"]\n",
    "try:\n",
    "    get_data_with_binary_classes(data_dir, target_labels)\n",
    "except Exception as exc:\n",
    "    print(f\"Exception occur:{exc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
